{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "74b23c92",
   "metadata": {},
   "source": [
    "# Assess predictions on text classification covid 19 data with a huggingface transformers model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d090f3b",
   "metadata": {},
   "source": [
    "This notebook demonstrates the use of the `responsibleai` API to assess a text classification huggingface transformers model trained on the covid 19 events dataset (see https://huggingface.co/datasets/joelito/covid19_emergency_event for more information about the dataset). It walks through the API calls necessary to create a widget with model analysis insights, then guides a visual analysis of the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "673e3b88",
   "metadata": {},
   "source": [
    "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n",
    "    * [Load Model and Data](#Load-Model-and-Data)\n",
    "    * [Create Model and Data Insights](#Create-Model-and-Data-Insights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b87b1db5",
   "metadata": {},
   "source": [
    "## Launch Responsible AI Toolbox"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be74708c",
   "metadata": {},
   "source": [
    "The following section examines the code necessary to create datasets and a model. It then generates insights using the `responsibleai` API that can be visually analyzed."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7af634c6",
   "metadata": {},
   "source": [
    "### Load Model and Data\n",
    "*The following section can be skipped. It loads a dataset and trains a model for illustrative purposes.*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c4adedd",
   "metadata": {},
   "source": [
    "First we import all necessary dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ef9e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import zipfile\n",
    "from transformers import (AutoModelForSequenceClassification, AutoTokenizer,\n",
    "                          pipeline)\n",
    "\n",
    "from raiutils.common.retries import retry_function\n",
    "from pathlib import Path\n",
    "\n",
    "try:\n",
    "    from urllib import urlretrieve\n",
    "except ImportError:\n",
    "    from urllib.request import urlretrieve"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7ba11f5",
   "metadata": {},
   "source": [
    "Next we load the covid 19 events dataset from huggingface datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ae1bf09",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_TEST_SAMPLES = 100\n",
    "\n",
    "def load_dataset(split):\n",
    "    dataset = datasets.load_dataset(\"joelito/covid19_emergency_event\", split=split)\n",
    "    return pd.DataFrame({\"language\": dataset[\"language\"],\n",
    "                         \"text\": dataset[\"text\"],\n",
    "                         \"event1\": dataset[\"event1\"],\n",
    "                         \"event2\": dataset[\"event2\"],\n",
    "                         \"event3\": dataset[\"event3\"],\n",
    "                         \"event4\": dataset[\"event4\"],\n",
    "                         \"event5\": dataset[\"event5\"],\n",
    "                         \"event6\": dataset[\"event6\"],\n",
    "                         \"event7\": dataset[\"event7\"],\n",
    "                         \"event8\": dataset[\"event8\"]})\n",
    "\n",
    "def select_english_subset(dataset):\n",
    "    # select only English subset\n",
    "    dataset = dataset[dataset.language == \"en\"].reset_index(drop=True)\n",
    "    dataset = dataset.drop(columns=\"language\")\n",
    "    return dataset\n",
    "\n",
    "pd_data = load_dataset(\"train\")\n",
    "pd_data = select_english_subset(pd_data)\n",
    "pd_valid_data = load_dataset(\"test\")\n",
    "pd_valid_data = select_english_subset(pd_valid_data)\n",
    "train_data = pd_data\n",
    "test_data = pd_valid_data"
   ]
  },
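  {
   "cell_type": "markdown",
   "id": "e4a91c07",
   "metadata": {},
   "source": [
    "Before assessing a multi-label model it can help to check how often each label occurs. The sketch below shows one way to do this on a small made-up frame that stands in for `train_data`. The toy values and the `toy_data`, `label_columns` and `prevalence` names are illustrative, and the sketch assumes each event column holds 0/1 indicators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b7d20f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Made-up stand-in for train_data (assumption: 0/1 event columns)\n",
    "toy_data = pd.DataFrame({'text': ['a', 'b', 'c', 'd'],\n",
    "                         'event1': [1, 0, 1, 1],\n",
    "                         'event2': [0, 0, 1, 0]})\n",
    "\n",
    "# Select the label columns by their shared prefix\n",
    "label_columns = [col for col in toy_data.columns if col.startswith('event')]\n",
    "\n",
    "# The mean of a 0/1 column is the fraction of rows carrying that label\n",
    "prevalence = toy_data[label_columns].mean()\n",
    "print(prevalence)"
   ]
  },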
  {
   "cell_type": "markdown",
   "id": "9de3b558",
   "metadata": {},
   "source": [
    "Fetch a pre-trained huggingface model on the DBPedia dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36e20709",
   "metadata": {},
   "outputs": [],
   "source": [
    "COVID19_EVENTS_MODEL_NAME = \"covid19_events_model\"\n",
    "NUM_LABELS = 8\n",
    "\n",
    "labels = [\"event1\", \"event2\", \"event3\", \"event4\", \"event5\", \"event6\", \"event7\", \"event8\"]\n",
    "\n",
    "id2label = {idx:label for idx, label in enumerate(labels)}\n",
    "label2id = {label:idx for idx, label in enumerate(labels)}\n",
    "\n",
    "class FetchModel(object):\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def fetch(self):\n",
    "        zipfilename = COVID19_EVENTS_MODEL_NAME + '.zip'\n",
    "        if not Path(zipfilename).exists():\n",
    "            url = ('https://publictestdatasets.blob.core.windows.net/models/' +\n",
    "                   COVID19_EVENTS_MODEL_NAME + '.zip')\n",
    "            urlretrieve(url, zipfilename)\n",
    "        with zipfile.ZipFile(zipfilename, 'r') as unzip:\n",
    "            unzip.extractall(COVID19_EVENTS_MODEL_NAME)\n",
    "\n",
    "def retrieve_covid19_events_model():\n",
    "    fetcher = FetchModel()\n",
    "    action_name = \"Model download\"\n",
    "    err_msg = \"Failed to download model\"\n",
    "    max_retries = 4\n",
    "    retry_delay = 60\n",
    "    retry_function(fetcher.fetch, action_name, err_msg,\n",
    "                   max_retries=max_retries,\n",
    "                   retry_delay=retry_delay)\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(\n",
    "        COVID19_EVENTS_MODEL_NAME, num_labels=NUM_LABELS,\n",
    "        problem_type=\"multi_label_classification\",\n",
    "        id2label=id2label,\n",
    "        label2id=label2id)\n",
    "    return model\n",
    "\n",
    "model = retrieve_covid19_events_model()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3ce1bfd",
   "metadata": {},
   "source": [
    "Load the model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93f3d7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the model and tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "device = -1\n",
    "if device >= 0:\n",
    "    model = model.cuda()\n",
    "\n",
    "# build a pipeline object to do predictions\n",
    "pred = pipeline(\n",
    "    \"text-classification\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    device=device,\n",
    "    return_all_scores=True\n",
    ")"
   ]
  },
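  {
   "cell_type": "markdown",
   "id": "8c3f6a1e",
   "metadata": {},
   "source": [
    "With `return_all_scores=True` the pipeline returns a score for every label rather than only the top one, which suits a multi-label task where several events can apply to one text. The sketch below shows one way such scores could be turned into 0/1 indicators. The scores are made up, and the `scores_to_indicators` helper and its 0.5 threshold are assumptions for illustration, not part of the `transformers` or `responsibleai_text` APIs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d29b47a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical pipeline output for two texts: one list of\n",
    "# label/score dicts per text (the scores are made up)\n",
    "example_output = [\n",
    "    [{'label': 'event1', 'score': 0.91}, {'label': 'event2', 'score': 0.08}],\n",
    "    [{'label': 'event1', 'score': 0.12}, {'label': 'event2', 'score': 0.77}],\n",
    "]\n",
    "\n",
    "def scores_to_indicators(outputs, threshold=0.5):\n",
    "    # Assumed convention: a label counts as predicted when its\n",
    "    # score is at or above the threshold\n",
    "    return [[int(item['score'] >= threshold) for item in row]\n",
    "            for row in outputs]\n",
    "\n",
    "indicators = scores_to_indicators(example_output)\n",
    "print(indicators)"
   ]
  },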
  {
   "cell_type": "markdown",
   "id": "ff755806",
   "metadata": {},
   "source": [
    "### Create Model and Data Insights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4edc42",
   "metadata": {},
   "outputs": [],
   "source": [
    "from responsibleai_text import RAITextInsights, ModelTask\n",
    "from raiwidgets import ResponsibleAIDashboard"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab3737a7",
   "metadata": {},
   "source": [
    "To use Responsible AI Dashboard, initialize a RAITextInsights object upon which different components can be loaded.\n",
    "\n",
    "RAITextInsights accepts the model, the test dataset, the classes and the task type as its arguments."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc8021d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights = RAITextInsights(pred, test_data[:3],\n",
    "                               labels,\n",
    "                               task_type=ModelTask.MULTILABEL_TEXT_CLASSIFICATION)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87ba5b5c",
   "metadata": {},
   "source": [
    "Add the components of the toolbox for model assessment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0fa47fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights.error_analysis.add()\n",
    "rai_insights.explainer.add()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62612ce7",
   "metadata": {},
   "source": [
    "Once all the desired components have been loaded, compute insights on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b94326d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "rai_insights.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23c0e75e",
   "metadata": {},
   "source": [
    "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b63dbfa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "ResponsibleAIDashboard(rai_insights)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
